from scripts.datasets.clipping import TextDatasetClippingCondition, TextDatasetClippingHparams, \
    make_text_dataset_clipping_transform
from scripts.datasets.presets import load_preset_text_dataset


def browse_text_dataset(
    text_dataset_name: str,
    text_dataset_split: str | None,
    text_dataset_clipping_hparams: TextDatasetClippingHparams | None = None,
) -> None:
    """Interactively page through a preset text dataset on the console.

    The dataset is loaded with ``load_preset_text_dataset``, passed through
    the transform built by ``make_text_dataset_clipping_transform``, and then
    printed one item at a time. After each item the function blocks on
    ``input()`` until the user presses Enter.

    Args:
        text_dataset_name: Preset name forwarded to
            ``load_preset_text_dataset``.
        text_dataset_split: Split forwarded to ``load_preset_text_dataset``;
            may be ``None``.
        text_dataset_clipping_hparams: Forwarded unchanged to
            ``make_text_dataset_clipping_transform``. How ``None`` is
            interpreted is decided by that helper, not here.
    """
    dataset = load_preset_text_dataset(text_dataset_name, text_dataset_split)
    transform = make_text_dataset_clipping_transform(text_dataset_clipping_hparams)
    dataset = transform(dataset)
    for text in dataset:
        print(text)
        try:
            # Pause until the user presses Enter before showing the next item.
            input()
        except EOFError:
            # stdin was closed (Ctrl-D, or input piped/redirected and
            # exhausted): end the browsing session instead of crashing
            # with a traceback.
            return


if __name__ == "__main__":
    # Example session: browse the "train" split of the "wikien" preset with
    # one clipping condition per tokenizer (llama3: 192 tokens max,
    # utf8: 1024 tokens max).
    clipping_conditions = [
        TextDatasetClippingCondition(text_tokenizer_name=tokenizer_name, max_tokens_n=token_limit)
        for tokenizer_name, token_limit in (("llama3", 192), ("utf8", 1024))
    ]
    clipping_hparams = TextDatasetClippingHparams(
        multiple=False,
        conditions=clipping_conditions,
    )
    browse_text_dataset(
        text_dataset_name="wikien",
        text_dataset_split="train",
        text_dataset_clipping_hparams=clipping_hparams,
    )
